# encoding:utf-8
import psycopg2
import pandas as pd
import argparse
import io
import pymysql
import json
import numpy as np


def save_gp(file_path, table_name, index_col):
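    """Bulk-load a CSV file into the Greenplum table `table_name`.

    Connection settings come from the module-level `args` namespace.
    Returns (column names, inferred column types, row count).
    """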
    print("loading {}".format(file_path))
    if index_col is None:
        index_col = '_record_id_'
    # `error_bad_lines` was removed in pandas 2.0; `on_bad_lines="skip"` is
    # the current spelling for skipping malformed rows (pandas >= 1.3).
    df = pd.read_csv(file_path, on_bad_lines="skip")
    row_num = df.shape[0]
    conn = psycopg2.connect(dbname=args.gp_db,
                            user=args.gp_user,
                            password=args.gp_password,
                            host=args.gp_host,
                            port=args.gp_port)
    cursor = conn.cursor()

    try:
        sql = "drop table if exists {}".format(table_name)
        cursor.execute(sql)
        df_columns = list(df.columns)

        # Drop pandas' auto-numbered index columns ("Unnamed: 0", ...).
        col_to_del = [col for col in df.columns if "Unnamed: 0" in col]
        if col_to_del:
            df = df.drop(columns=col_to_del)

        # Add a surrogate key column if the file does not already carry one.
        if "_record_id_" not in df.columns:
            df.insert(0, "_record_id_", np.arange(1, row_num + 1))

        df_columns = list(df.columns)

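        # Map pandas dtypes to Greenplum column types; anything non-numeric
        # falls back to varchar.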
        dtypes = []
        for dtype in df.dtypes:
            if dtype == float:
                dtypes.append("decimal")
            elif dtype == int:
                dtypes.append("integer")
            else:
                dtypes.append("varchar")

        print("CHECKTHIS" + str(list(df.columns)))
        columns = ["\""+column+"\"" + " " + dtype for column, dtype in zip(df.columns, dtypes)]
        #columns = [column + " " + dtype for column, dtype in zip(df_columns, dtypes)]
        sql = "create table {}({})".format(table_name, ",".join(columns))
        cursor.execute(sql)
        data_io = io.StringIO()
        df.to_csv(data_io, sep="|", index=False)
        data_io.seek(0)
        # COPY consumes the header row itself (HEADER), so it is not skipped
        # here. Options are ordered to match the pre-9.0 COPY grammar, which
        # both PostgreSQL and Greenplum still accept.
        copy_cmd = "COPY {} FROM STDIN WITH DELIMITER '|' CSV HEADER".format(table_name)
        cursor.copy_expert(copy_cmd, data_io)

        sql = f"create index {table_name.split('.')[1]}_index on {table_name}({index_col})"
        cursor.execute(sql)
        conn.commit()
    except Exception as e:
        print(e)
        conn.rollback()
    finally:
        cursor.close()
        conn.close()
    return df_columns, dtypes, row_num

def get_mysql_conn():
    """Open a MySQL connection from the CLI credentials (returns a connection, not a cursor)."""
    return pymysql.connect(db=args.mysql_db,
                           user=args.mysql_user,
                           password=args.mysql_password,
                           host=args.mysql_host,
                           port=args.mysql_port)

def read_mysql(sql):
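    """Run a query and return the result set as a DataFrame."""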
    conn = get_mysql_conn()
    cursor = conn.cursor()
    cursor.execute(sql)
    rows = cursor.fetchall()
    columns = [desc[0] for desc in cursor.description]
    df = pd.DataFrame(rows, columns=columns)
    cursor.close()
    conn.close()
    return df


def get_task_instance_data_json(task_instance_id):
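    """Fetch the newest data_json payload for a task instance, parsed into a dict."""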
    series = read_mysql("select * from task_instance where id = '{}'".format(task_instance_id)).iloc[-1]
    data_json = json.loads(series["data_json"])
    return data_json


def execute_mysql(sql, params=None):
    """Execute a write statement (optionally parameterized) and commit."""
    conn = get_mysql_conn()
    cursor = conn.cursor()
    cursor.execute(sql, params)
    conn.commit()
    cursor.close()
    conn.close()


def py_to_java(s):
    """Rewrite a Python dict repr into JSON-style text for the Java consumer."""
    # Note: this is a plain-text rewrite, so it also touches matches inside
    # string values; json.dumps is the safer route when the input is a dict.
    return s.replace("'", '"') \
        .replace("False", "false") \
        .replace("True", "true") \
        .replace(", ", ",") \
        .replace(": ", ":") \
        .replace("None", "null") \
        .replace("nan", "null")
    # .replace("inf", "2147483647")


def update_task_instance(data_json, task_instance_id):
    print(data_json)
    data = py_to_java(str(data_json))
    # Parameterized to avoid quoting problems with the serialized JSON text.
    sql = "update task_instance set data_json=%s where id=%s"
    execute_mysql(sql, (data, task_instance_id))


def update_data_json(task_instance_id, df_columns, dtypes, table_name, row_num):
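    """Record the output table's schema and row count in the task's data_json."""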
    data_json = get_task_instance_data_json(task_instance_id)
    semantics = {col: "null" for col in df_columns}
    data_json_output = {
        "tableCols": df_columns
        , "columnTypes": dtypes
        , "tableName": table_name
        , "semantic": semantics
        , "totalRow": row_num
    }

    data_json["inputInfo"]["output"] = [data_json_output]

    update_task_instance(data_json, task_instance_id)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", type=str, required=True)
    parser.add_argument("--table_name", type=str, required=True)
    parser.add_argument("--gp_db",type=str, required=True)
    parser.add_argument("--gp_user",type=str, required=True)
    parser.add_argument("--gp_password",type=str, required=True)
    parser.add_argument("--gp_host",type=str, required=True)
    parser.add_argument("--gp_port",type=int, required=True)

    args = parser.parse_args()

    df_columns, dtypes, row_num = save_gp(args.file_path, args.table_name, None)

